{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "12ada6c3",
   "metadata": {},
   "source": [
    "(tune-lightgbm-example)=\n",
    "\n",
    "# Using LightGBM with Tune\n",
    "\n",
    "<a id=\"try-anyscale-quickstart-ray-tune-lightgbm_example\" href=\"https://console.anyscale.com/register/ha?render_flow=ray&utm_source=ray_docs&utm_medium=docs&utm_campaign=ray-tune-lightgbm_example\">\n",
    "    <img src=\"../../_static/img/run-on-anyscale.svg\" alt=\"try-anyscale-quickstart\">\n",
    "</a>\n",
    "<br></br>\n",
    "\n",
    "```{image} /images/lightgbm_logo.png\n",
    ":align: center\n",
    ":alt: LightGBM Logo\n",
    ":height: 120px\n",
    ":target: https://lightgbm.readthedocs.io\n",
    "```\n",
    "\n",
    "```{contents}\n",
    ":backlinks: none\n",
    ":local: true\n",
    "```\n",
    "\n",
    "This tutorial shows how to use Ray Tune to optimize hyperparameters for a LightGBM model. We'll use the breast cancer binary classification dataset bundled with scikit-learn to demonstrate how to:\n",
    "\n",
    "1. Set up a LightGBM training function with Ray Tune\n",
    "2. Configure hyperparameter search spaces\n",
    "3. Use the ASHA scheduler to stop unpromising trials early\n",
    "4. Report metrics and save checkpoints during training with `TuneReportCheckpointCallback`\n",
    "\n",
    "## Installation\n",
    "\n",
    "First, let's install the required dependencies:\n",
    "\n",
    "```bash\n",
    "pip install \"ray[tune]\" lightgbm scikit-learn numpy\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b4c3f1e1",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": [
     "hide-output"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"tuneStatus\">\n",
       "  <div style=\"display: flex;flex-direction: row\">\n",
       "    <div style=\"display: flex;flex-direction: column;\">\n",
       "      <h3>Tune Status</h3>\n",
       "      <table>\n",
       "<tbody>\n",
       "<tr><td>Current time:</td><td>2025-02-18 17:33:55</td></tr>\n",
       "<tr><td>Running for: </td><td>00:00:01.27        </td></tr>\n",
       "<tr><td>Memory:      </td><td>25.8/36.0 GiB      </td></tr>\n",
       "</tbody>\n",
       "</table>\n",
       "    </div>\n",
       "    <div class=\"vDivider\"></div>\n",
       "    <div class=\"systemInfo\">\n",
       "      <h3>System Info</h3>\n",
       "      Using AsyncHyperBand: num_stopped=4<br>Bracket: Iter 64.000: -0.1048951048951049 | Iter 16.000: -0.3076923076923077 | Iter 4.000: -0.3076923076923077 | Iter 1.000: -0.32342657342657344<br>Logical resource usage: 1.0/12 CPUs, 0/0 GPUs\n",
       "    </div>\n",
       "    \n",
       "  </div>\n",
       "  <div class=\"hDivider\"></div>\n",
       "  <div class=\"trialStatus\">\n",
       "    <h3>Trial Status</h3>\n",
       "    <table>\n",
       "<thead>\n",
       "<tr><th>Trial name                     </th><th>status    </th><th>loc            </th><th>boosting_type  </th><th style=\"text-align: right;\">  learning_rate</th><th style=\"text-align: right;\">  num_leaves</th><th style=\"text-align: right;\">  iter</th><th style=\"text-align: right;\">  total time (s)</th><th style=\"text-align: right;\">  binary_error</th><th style=\"text-align: right;\">  binary_logloss</th></tr>\n",
       "</thead>\n",
       "<tbody>\n",
       "<tr><td>train_breast_cancer_945ea_00000</td><td>TERMINATED</td><td>127.0.0.1:26189</td><td>gbdt           </td><td style=\"text-align: right;\">    0.00372129 </td><td style=\"text-align: right;\">         622</td><td style=\"text-align: right;\">   100</td><td style=\"text-align: right;\">      0.0507247 </td><td style=\"text-align: right;\">      0.104895</td><td style=\"text-align: right;\">        0.45487 </td></tr>\n",
       "<tr><td>train_breast_cancer_945ea_00001</td><td>TERMINATED</td><td>127.0.0.1:26191</td><td>dart           </td><td style=\"text-align: right;\">    0.0065691  </td><td style=\"text-align: right;\">         998</td><td style=\"text-align: right;\">     1</td><td style=\"text-align: right;\">      0.013751  </td><td style=\"text-align: right;\">      0.391608</td><td style=\"text-align: right;\">        0.665636</td></tr>\n",
       "<tr><td>train_breast_cancer_945ea_00002</td><td>TERMINATED</td><td>127.0.0.1:26190</td><td>gbdt           </td><td style=\"text-align: right;\">    1.17012e-07</td><td style=\"text-align: right;\">         995</td><td style=\"text-align: right;\">     1</td><td style=\"text-align: right;\">      0.0146749 </td><td style=\"text-align: right;\">      0.412587</td><td style=\"text-align: right;\">        0.68387 </td></tr>\n",
       "<tr><td>train_breast_cancer_945ea_00003</td><td>TERMINATED</td><td>127.0.0.1:26192</td><td>dart           </td><td style=\"text-align: right;\">    0.000194983</td><td style=\"text-align: right;\">          53</td><td style=\"text-align: right;\">     1</td><td style=\"text-align: right;\">      0.00605583</td><td style=\"text-align: right;\">      0.328671</td><td style=\"text-align: right;\">        0.6405  </td></tr>\n",
       "</tbody>\n",
       "</table>\n",
       "  </div>\n",
       "</div>\n",
       "<style>\n",
       ".tuneStatus {\n",
       "  color: var(--jp-ui-font-color1);\n",
       "}\n",
       ".tuneStatus .systemInfo {\n",
       "  display: flex;\n",
       "  flex-direction: column;\n",
       "}\n",
       ".tuneStatus td {\n",
       "  white-space: nowrap;\n",
       "}\n",
       ".tuneStatus .trialStatus {\n",
       "  display: flex;\n",
       "  flex-direction: column;\n",
       "}\n",
       ".tuneStatus h3 {\n",
       "  font-weight: bold;\n",
       "}\n",
       ".tuneStatus .hDivider {\n",
       "  border-bottom-width: var(--jp-border-width);\n",
       "  border-bottom-color: var(--jp-border-color0);\n",
       "  border-bottom-style: solid;\n",
       "}\n",
       ".tuneStatus .vDivider {\n",
       "  border-left-width: var(--jp-border-width);\n",
       "  border-left-color: var(--jp-border-color0);\n",
       "  border-left-style: solid;\n",
       "  margin: 0.5em 1em 0.5em 1em;\n",
       "}\n",
       "</style>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-02-18 17:33:55,300\tINFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/Users/rdecal/ray_results/train_breast_cancer_2025-02-18_17-33-54' in 0.0035s.\n",
      "2025-02-18 17:33:55,302\tINFO tune.py:1041 -- Total run time: 1.28 seconds (1.27 seconds for the tuning loop).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Best hyperparameters found were: {'objective': 'binary', 'metric': ['binary_error', 'binary_logloss'], 'verbose': -1, 'boosting_type': 'gbdt', 'num_leaves': 622, 'learning_rate': 0.003721286118355498}\n"
     ]
    }
   ],
   "source": [
    "import lightgbm as lgb\n",
    "import numpy as np\n",
    "import sklearn.datasets\n",
    "import sklearn.metrics\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "from ray import tune\n",
    "from ray.tune.schedulers import ASHAScheduler\n",
    "from ray.tune.integration.lightgbm import TuneReportCheckpointCallback\n",
    "\n",
    "\n",
    "def train_breast_cancer(config):\n",
    "    # Load the dataset and hold out 25% of it as a validation set.\n",
    "    data, target = sklearn.datasets.load_breast_cancer(return_X_y=True)\n",
    "    train_x, test_x, train_y, test_y = train_test_split(data, target, test_size=0.25)\n",
    "    train_set = lgb.Dataset(train_x, label=train_y)\n",
    "    test_set = lgb.Dataset(test_x, label=test_y)\n",
    "    gbm = lgb.train(\n",
    "        config,\n",
    "        train_set,\n",
    "        valid_sets=[test_set],\n",
    "        valid_names=[\"eval\"],\n",
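    "        # Report metrics to Tune after every boosting round and save model checkpoints.\n",
    "        # LightGBM names validation metrics \"<valid_name>-<metric>\", so\n",
    "        # \"eval-binary_error\" is the binary_error computed on the \"eval\" set.\n",
    "        callbacks=[\n",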
    "            TuneReportCheckpointCallback(\n",
    "                {\n",
    "                    \"binary_error\": \"eval-binary_error\",\n",
    "                    \"binary_logloss\": \"eval-binary_logloss\",\n",
    "                }\n",
    "            )\n",
    "        ],\n",
    "    )\n",
    "    # Round the predicted probabilities to 0/1 labels and report the final accuracy.\n",
    "    preds = gbm.predict(test_x)\n",
    "    pred_labels = np.rint(preds)\n",
    "    tune.report(\n",
    "        {\n",
    "            \"mean_accuracy\": sklearn.metrics.accuracy_score(test_y, pred_labels),\n",
    "            \"done\": True,\n",
    "        }\n",
    "    )\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
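    "    # Search space: a grid over boosting types plus randomly sampled\n",
    "    # num_leaves and learning_rate values.\n",
    "    config = {\n",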
    "        \"objective\": \"binary\",\n",
    "        \"metric\": [\"binary_error\", \"binary_logloss\"],\n",
    "        \"verbose\": -1,\n",
    "        \"boosting_type\": tune.grid_search([\"gbdt\", \"dart\"]),\n",
    "        \"num_leaves\": tune.randint(10, 1000),\n",
    "        \"learning_rate\": tune.loguniform(1e-8, 1e-1),\n",
    "    }\n",
    "\n",
    "    tuner = tune.Tuner(\n",
    "        train_breast_cancer,\n",
    "        tune_config=tune.TuneConfig(\n",
    "            metric=\"binary_error\",\n",
    "            mode=\"min\",\n",
    "            scheduler=ASHAScheduler(),  # stop poorly performing trials early\n",
    "            num_samples=2,  # repeat the 2-value grid twice: 4 trials in total\n",
    "        ),\n",
    "        param_space=config,\n",
    "    )\n",
    "    results = tuner.fit()\n",
    "\n",
    "    print(f\"Best hyperparameters found were: {results.get_best_result().config}\")"
   ]
  },
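  {
   "cell_type": "markdown",
   "id": "5d8e2a47",
   "metadata": {},
   "source": [
    "Tune expands `param_space` into trials by repeating every `tune.grid_search` combination `num_samples` times and drawing fresh random values for the sampled parameters. The following plain-Python sketch imitates that expansion for the three searched parameters in `config`, to show why the run produces four trials (two boosting types times `num_samples=2`). It is an illustration under stated assumptions, not Ray's implementation: `expand_search_space` is a hypothetical helper, and the sketch assumes that `tune.randint(10, 1000)` excludes its upper bound and that `tune.loguniform(1e-8, 1e-1)` samples uniformly in log space.\n",
    "\n",
    "```python\n",
    "import math\n",
    "import random\n",
    "\n",
    "\n",
    "def expand_search_space(num_samples, seed=None):\n",
    "    \"\"\"Hypothetical helper that mimics how Tune expands this tutorial's config.\"\"\"\n",
    "    rng = random.Random(seed)\n",
    "    trials = []\n",
    "    # The whole grid is repeated once per sample.\n",
    "    for _ in range(num_samples):\n",
    "        # tune.grid_search([\"gbdt\", \"dart\"]): one trial per listed value.\n",
    "        for boosting_type in [\"gbdt\", \"dart\"]:\n",
    "            trials.append(\n",
    "                {\n",
    "                    \"boosting_type\": boosting_type,\n",
    "                    # Like tune.randint(10, 1000): an integer in [10, 1000).\n",
    "                    \"num_leaves\": rng.randrange(10, 1000),\n",
    "                    # Like tune.loguniform(1e-8, 1e-1): uniform in log10 space.\n",
    "                    \"learning_rate\": 10 ** rng.uniform(math.log10(1e-8), math.log10(1e-1)),\n",
    "                }\n",
    "            )\n",
    "    return trials\n",
    "\n",
    "\n",
    "trials = expand_search_space(num_samples=2, seed=0)\n",
    "print(len(trials))  # 4\n",
    "for trial in trials:\n",
    "    print(trial[\"boosting_type\"], trial[\"num_leaves\"], f\"{trial['learning_rate']:.2e}\")\n",
    "```\n",
    "\n",
    "Sampling `learning_rate` in log space gives each order of magnitude between `1e-8` and `1e-1` the same probability, so very small learning rates are explored as often as larger ones."
   ]
  },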
  {
   "cell_type": "markdown",
   "id": "01d74c39",
   "metadata": {},
   "source": [
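    "The script prints the best configuration it found. Because Tune samples the search space randomly and each trial draws its own train/test split, the exact values vary between runs, but the output looks like this:\n",
    "\n",
    "```text\n",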
    "Best hyperparameters found were: {'objective': 'binary', 'metric': ['binary_error', 'binary_logloss'], 'verbose': -1, 'boosting_type': 'gbdt', 'num_leaves': 622, 'learning_rate': 0.003721286118355498}\n",
    "```"
   ]
  }
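  ,
  {
   "cell_type": "markdown",
   "id": "9b3f6c18",
   "metadata": {},
   "source": [
    "The *Tune Status* panel in the output of the tuning cell reports the ASHA bracket as rungs at iterations 64, 16, 4, and 1, each followed by a negative score. The negative values are consistent with ASHA negating `binary_error` internally because `mode=\"min\"`, so that higher is always better. The following plain-Python sketch approximates how those milestones follow from `ASHAScheduler`'s defaults and how a trial is judged when it reaches a rung. It assumes the defaults `max_t=100`, `grace_period=1`, and `reduction_factor=4`, and a cutoff at the `(1 - 1/reduction_factor)` percentile of the scores already recorded at that rung; `rung_milestones`, `percentile`, and `should_continue` are hypothetical helpers, not Ray APIs.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def rung_milestones(max_t=100, grace_period=1, reduction_factor=4):\n",
    "    \"\"\"Hypothetical helper: iterations at which ASHA compares trials.\"\"\"\n",
    "    num_rungs = int(math.log(max_t / grace_period) / math.log(reduction_factor) + 1)\n",
    "    return [grace_period * reduction_factor**k for k in range(num_rungs)]\n",
    "\n",
    "\n",
    "def percentile(values, q):\n",
    "    \"\"\"Linear-interpolation percentile of values, with q in [0, 100].\"\"\"\n",
    "    ordered = sorted(values)\n",
    "    pos = (len(ordered) - 1) * q / 100\n",
    "    lower, upper = math.floor(pos), math.ceil(pos)\n",
    "    return ordered[lower] + (ordered[upper] - ordered[lower]) * (pos - lower)\n",
    "\n",
    "\n",
    "def should_continue(score, recorded, reduction_factor=4):\n",
    "    \"\"\"Hypothetical helper: keep a trial only if its score reaches the rung cutoff.\"\"\"\n",
    "    # Scores are higher-is-better, so a minimized metric is passed in negated.\n",
    "    if not recorded:\n",
    "        return True  # nothing to compare against yet\n",
    "    cutoff = percentile(recorded, (1 - 1 / reduction_factor) * 100)\n",
    "    return score >= cutoff\n",
    "\n",
    "\n",
    "print(rung_milestones())  # [1, 4, 16, 64]\n",
    "# Illustrative negated binary_error values already recorded at one rung.\n",
    "recorded = [-0.33, -0.31, -0.41]\n",
    "print(should_continue(-0.30, recorded))  # True: at or above the 75th-percentile cutoff\n",
    "print(should_continue(-0.39, recorded))  # False: the trial would be stopped\n",
    "```\n",
    "\n",
    "In the run shown above, three of the four trials stopped after a single boosting round, which is consistent with this kind of cutoff being applied from the first rung when `grace_period=1`."
   ]
  }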
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lightgbm_example",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  },
  "orphan": true
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
